{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cot import Collection\n",
    "from cot.stats import evaluation_as_table"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### generate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package punkt to /home/kon/nltk_data...\n",
      "[nltk_data]   Package punkt is already up-to-date!\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b1f3a30a0d714d1faea29358e44dceb1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c2bad80a2ecd4d6cbc965dbeb242bd4b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9bf61d46f8dd4d4e864c80b9fe3bd0e0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7c92ce5e0fb84c50b8599f7760f270ce",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "53541ca5940347d19a00f73ccfb9cf48",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8479e1ba13eb448f9be2f01ca077c471",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load the ThoughtSource-100 collection (\"all\" selector) without the pregenerated CoTs.\n",
    "coll = Collection.load_thoughtsource_100(\n",
    "    \"all\",\n",
    "    load_pregenerated_cots=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep 1 example per loaded split (the summary shown by the next cell lists a count of 1 for each dataset).\n",
    "coll = coll.select(\"all\", 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "| Name           | Train   | Valid   | Test   |\n",
       "|----------------|---------|---------|--------|\n",
       "| commonsense_qa | -       | 1       | -      |\n",
       "| med_qa         | -       | -       | 1      |\n",
       "| medmc_qa       | -       | 1       | -      |\n",
       "| open_book_qa   | -       | -       | 1      |\n",
       "| strategy_qa    | 1       | -       | -      |\n",
       "| worldtree      | -       | -       | 1      |\n",
       "\n",
       "Not loaded: ['aqua', 'asdiv', 'entailment_bank', 'gsm8k', 'mawps', 'pubmed_qa', 'qed', 'svamp']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the collection summary: loaded datasets with per-split counts, plus the datasets not loaded.\n",
    "coll"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unload five of the six loaded datasets so that only worldtree remains.\n",
    "coll.unload_datasets(\n",
    "    [\n",
    "        \"commonsense_qa\",\n",
    "        \"med_qa\",\n",
    "        \"medmc_qa\",\n",
    "        \"open_book_qa\",\n",
    "        \"strategy_qa\",\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "| Name      | Train   | Valid   |   Test |\n",
       "|-----------|---------|---------|--------|\n",
       "| worldtree | -       | -       |      1 |\n",
       "\n",
       "Not loaded: ['aqua', 'asdiv', 'commonsense_qa', 'entailment_bank', 'gsm8k', 'mawps', 'med_qa', 'medmc_qa', 'open_book_qa', 'pubmed_qa', 'qed', 'strategy_qa', 'svamp']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Confirm that worldtree is now the only loaded dataset.\n",
    "coll"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Configuration of the input and parameters of the language model \n",
    "# config={\n",
    "#     \"instruction_keys\": \"qa-08\",\n",
    "#     \"cot_trigger_keys\": None,\n",
    "#     \"answer_extraction_keys\": 'auto-kojima', \n",
    "#     \"author\" : \"thoughtsource\",\n",
    "#     \"api_service\": \"openai_chat\",\n",
    "#     \"api_time_interval\": 1,\n",
    "#     \"engine\": \"gpt-3.5-turbo\", \n",
    "#     \"temperature\": 0,\n",
    "#     \"max_tokens\": 512,\n",
    "#     \"verbose\": False,\n",
    "#     \"warn\": False,\n",
    "# }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generation settings (Cohere variant): prompt fragment keys plus API / sampling parameters.\n",
    "# NOTE(review): the next code cell assigns `config` again (openai_chat / gpt-4-0314),\n",
    "# so on a top-to-bottom run that later value is the one passed to `coll.generate`.\n",
    "config = dict(\n",
    "    instruction_keys=\"qa-08\",\n",
    "    cot_trigger_keys=None,\n",
    "    answer_extraction_keys=\"auto-kojima\",\n",
    "    author=\"thoughtsource\",\n",
    "    api_service=\"cohere\",\n",
    "    api_time_interval=1,\n",
    "    engine=\"command-xlarge-nightly\",\n",
    "    temperature=0,\n",
    "    max_tokens=512,\n",
    "    verbose=False,\n",
    "    warn=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generation settings (OpenAI chat variant, gpt-4-0314): prompt fragment keys plus API / sampling parameters.\n",
    "config = dict(\n",
    "    instruction_keys=\"qa-08\",\n",
    "    cot_trigger_keys=None,\n",
    "    answer_extraction_keys=\"auto-kojima\",\n",
    "    author=\"thoughtsource\",\n",
    "    api_service=\"openai_chat\",\n",
    "    api_time_interval=1,\n",
    "    engine=\"gpt-4-0314\",\n",
    "    temperature=0,\n",
    "    max_tokens=512,\n",
    "    verbose=False,\n",
    "    warn=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating worldtree...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b41a8d16382c42518bb67856392b4856",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Generate CoTs for every loaded dataset using the settings in `config`\n",
    "# (presumably sends requests to the configured `api_service` - confirm cost before re-running).\n",
    "coll.generate(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "61de9d344ce2434b8076ae21a170a184",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'worldtree': {'test': {'accuracy': {'gpt-4': {'qa-08_None_kojima-A-D': 1.0}}}}}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score the generated answers; the result is a nested dict:\n",
    "# dataset -> split -> metric -> model -> prompt-key combination -> value.\n",
    "coll.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d7b3720e58694ad593d41c033d0ef5e0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Write the collection to a scratch file.\n",
    "# NOTE(review): \"foo_now\" looks like a throwaway name; the dump cell further down builds a\n",
    "# descriptive filename and appears to be the intended export - confirm whether both are needed.\n",
    "coll.dump(\"foo_now\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def join_strings(list_of_strings):\n",
    "    \"\"\"Join the given key(s) into a single hyphen-separated string.\n",
    "\n",
    "    Args:\n",
    "        list_of_strings: ``None``, a single string, or an iterable of items.\n",
    "            ``None`` is treated as ``[\"None\"]`` and a bare string as a\n",
    "            one-element list; every item is converted with ``str()``.\n",
    "\n",
    "    Returns:\n",
    "        The items joined with ``-`` (an empty string for an empty iterable).\n",
    "    \"\"\"\n",
    "    if list_of_strings is None:\n",
    "        list_of_strings = [\"None\"]\n",
    "    if isinstance(list_of_strings, str):\n",
    "        list_of_strings = [list_of_strings]\n",
    "    # str.join replaces the manual prefix-then-strip concatenation loop.\n",
    "    return \"-\".join(str(item) for item in list_of_strings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the collection under a filename assembled from the generation settings:\n",
    "# thoughtsource_100_<api_service>_<engine>_<instruction keys>_<cot trigger keys>.json\n",
    "coll.dump(\n",
    "    \"_\".join(\n",
    "        [\n",
    "            \"thoughtsource_100\",\n",
    "            config[\"api_service\"],\n",
    "            config[\"engine\"].replace(\"/\", \"_\"),  # keep \"/\" out of the filename\n",
    "            join_strings(config[\"instruction_keys\"]),\n",
    "            join_strings(config[\"cot_trigger_keys\"]),\n",
    "        ]\n",
    "    )\n",
    "    + \".json\"\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### merge into `thoughtsource_100.json`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the existing thoughtsource_100.json collection that the new generations get merged into.\n",
    "ts_100 = Collection.from_json(\"thoughtsource_100.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge the freshly generated `coll` into `ts_100`.\n",
    "# NOTE(review): `ts_100` is rebound to its own merge result, so re-running this cell merges\n",
    "# `coll` again; how `merge` treats duplicates is not visible here - confirm.\n",
    "ts_100 = ts_100.merge(coll)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the merged collection back to the same path it was loaded from.\n",
    "ts_100.dump(\"thoughtsource_100.json\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the collection from thoughtsource_100.json.\n",
    "# This rebinds `coll`, replacing the collection used in the generate section.\n",
    "coll = Collection.from_json(\"thoughtsource_100.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restrict the generated CoTs to author 'thoughtsource' and model 'gpt-3.5-turbo' before evaluating.\n",
    "# Return values are discarded, so these calls presumably filter `coll` in place - TODO confirm.\n",
    "# The commented-out lines are alternative filters.\n",
    "# coll.select_generated_cots(author='thoughtsource', model=[\"gpt-3.5-turbo\", \"flan-T5-xxl\"], cot_trigger=None, instruction=None)\n",
    "coll.select_generated_cots(author='thoughtsource')\n",
    "coll.select_generated_cots(model='gpt-3.5-turbo')\n",
    "# coll.select_generated_cots(cot_trigger=[None,'zhou-01'])\n",
    "# coll.select_generated_cots(instruction=[None])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the selected CoTs and render the nested result dict as a table.\n",
    "# Named `evaluation` rather than `eval` so the builtin eval() is not shadowed in the kernel.\n",
    "evaluation = coll.evaluate()\n",
    "table = evaluation_as_table(evaluation)\n",
    "table"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
